{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import rasterio\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "import tifffile\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.measure import label, regionprops\n",
    "\n",
    "target_root = \"E:/睡眠分期数据/hubmap/\"\n",
    "img_root = target_root + \"test\"\n",
    "filelist = os.listdir(img_root)\n",
    "\n",
    "def mask2rle(mask):\n",
    "    ''' takes a 2d boolean numpy array and turns it into a space-delimited RLE string '''\n",
    "    mask = mask.T.reshape(-1) # make 1D, column-first\n",
    "    mask = np.pad(mask, 1, mode=\"constant\") # make sure that the 1d mask starts and ends with a 0\n",
    "    starts = np.nonzero((~mask[:-1]) & mask[1:])[0] # start points\n",
    "    ends = np.nonzero(mask[:-1] & (~mask[1:]))[0] # end points\n",
    "    rle = np.empty(2 * starts.size, dtype=int) # interlacing...\n",
    "    rle[0::2] = starts + 1# ...starts...\n",
    "    rle[1::2] = ends - starts # ...and lengths\n",
    "    rle = ' '.join([ str(elem) for elem in rle ]) # turn into space-separated string\n",
    "    return rle\n",
    "\n",
    "def rle_decode(mask_rle, shape=(256, 256)):\n",
    "    '''\n",
    "    mask_rle: run-length as string formated (start length)\n",
    "    shape: (height,width) of array to return \n",
    "    Returns numpy array, 1 - mask, 0 - background\n",
    "\n",
    "    '''\n",
    "    s = mask_rle.split()\n",
    "    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]\n",
    "    starts -= 1\n",
    "    ends = starts + lengths\n",
    "    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)\n",
    "    for lo, hi in zip(starts, ends):\n",
    "        img[lo:hi] = 1\n",
    "    return img.reshape(shape, order='F')\n",
    "\n",
    "def step1(mask):\n",
    "    mask_re = cv2.resize(mask, (2048, 2048))\n",
    "    label_img = label(mask_re, connectivity = mask_re.ndim)\n",
    "    props = regionprops(label_img)\n",
    "    for p in props:\n",
    "        if p.area > 100:\n",
    "            continue\n",
    "        bbox = p.bbox\n",
    "        bbox = list(bbox)\n",
    "        bbox[0] = int(bbox[0]* mask.shape[0]/2048)\n",
    "        bbox[2] = int(bbox[2]* mask.shape[0]/2048)\n",
    "        bbox[1] = int(bbox[1]* mask.shape[1]/2048)\n",
    "        bbox[3] = int(bbox[3] * mask.shape[1]/2048)\n",
    "        mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 0\n",
    "    return mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正在处理 26dc41664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\rasterio\\__init__.py:229: NotGeoreferencedWarning: Dataset has no geotransform set. The identity matrix may be returned.\n",
      "  s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正在处理 afa5e8098\n",
      "正在处理 b2dc8411c\n",
      "正在处理 b9a3865fc\n",
      "正在处理 c68fe75ea\n"
     ]
    }
   ],
   "source": [
    "csvfloder = \"./submission\"\n",
    "predcsvs = []\n",
    "for fn in os.listdir(csvfloder):\n",
    "    predcsvs.append(pd.read_csv(\"{}/{}\".format(csvfloder, fn)))\n",
    "\n",
    "out_info = {}\n",
    "identity = rasterio.Affine(1, 0, 0, 0, 1, 0)\n",
    "for id_ in predcsvs[0]['id']:\n",
    "    print(\"正在处理 {}\".format(id_))\n",
    "    dataset = rasterio.open(\"{}/{}.tiff\".format(img_root, id_), transform = identity)\n",
    "    masks = np.zeros((len(predcsvs), dataset.shape[0], dataset.shape[1]), dtype=np.float16)\n",
    "    for i, csv in  enumerate(predcsvs):\n",
    "        mask = rle_decode(csv.loc[csv['id'] == id_]['predicted'].values[0], dataset.shape)\n",
    "        masks[i] = step1(mask)\n",
    "    del mask\n",
    "    masks = np.mean(masks, axis=0, dtype=np.float16)\n",
    "    masks = (masks >= 0.5).astype(np.uint8)\n",
    "    masks = step1(masks)\n",
    "    out_info[len(out_info)] = {'id':id_, 'predicted': mask2rle(masks)}\n",
    "    del masks\n",
    "\n",
    "submission = pd.DataFrame.from_dict(out_info, orient='index')\n",
    "submission.to_csv('fixed_submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = np.random.randn(4,512,512)\n",
    "t = np.mean(masks, axis=0, dtype=np.float16)\n",
    "t = (t >= 0.5).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
